{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Federated Learning Training Plan: Create Plan\n",
    "\n",
    "This notebook is the first part of a tutorial that demonstrates how to create a Model-Centric Federated Learning process that trains a simple MNIST classifier.\n",
    "\n",
    "This part walks you through the following steps:\n",
    "1. Defining the Model\n",
    "1. Defining the Training Plan (that runs on the client)\n",
    "1. Defining the Averaging Plan (that runs on the server)\n",
    "1. Hosting all created assets in PyGrid\n",
    "\n",
    "Current known issues:\n",
    " * `tensor.shape` is not traceable inside the Plan (issue [#3554](https://github.com/OpenMined/PySyft/issues/3554)).\n",
    " * Autograd/Plan tracing doesn't work with PyTorch's built-in loss functions and optimizers.\n",
    " * Possibly others.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up Sandbox...\n",
      "Done!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f76090127f0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import syft as sy\n",
    "from syft.serde import protobuf\n",
    "from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB\n",
    "from syft_proto.execution.v1.state_pb2 import State as StatePB\n",
    "from syft.grid.clients.model_centric_fl_client import ModelCentricFLClient\n",
    "from syft.execution.state import State\n",
    "from syft.execution.placeholder import PlaceHolder\n",
    "from syft.execution.translation import TranslationTarget\n",
    "\n",
    "import torch as th\n",
    "from torch import nn\n",
    "\n",
    "import os\n",
    "import websockets\n",
    "import json\n",
    "import requests\n",
    "\n",
    "sy.make_hook(globals())\n",
    "# force protobuf serialization for tensors\n",
    "hook.local_worker.framework = None\n",
    "th.random.manual_seed(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This utility function recursively replaces the model's parameters with the given tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def set_model_params(module, params_list, start_param_idx=0):\n",
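    "    \"\"\"Recursively assign tensors from params_list to the module's parameters.\n",
    "\n",
    "    Tensors are consumed in the same order as module.parameters()\n",
    "    (assuming no shared parameters). Returns the index of the next unused tensor.\n",
    "    \"\"\"\n",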
    "    param_idx = start_param_idx\n",
    "\n",
    "    for name, param in module._parameters.items():\n",
    "        module._parameters[name] = params_list[param_idx]\n",
    "        param_idx += 1\n",
    "\n",
    "    for name, child in module._modules.items():\n",
    "        if child is not None:\n",
    "            param_idx = set_model_params(child, params_list, param_idx)\n",
    "\n",
    "    return param_idx"
   ]
  },
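  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check outside of any Plan, the sketch below uses plain PyTorch and a small hypothetical `nn.Sequential` model (not the MNIST `Net`).\n",
    "It repeats the helper so the snippet is self-contained, injects zero tensors in the order returned by `parameters()`, and confirms that the model then computes with them.\n",
    "It assumes the helper's traversal order (own parameters first, then submodules) matches `parameters()`, which should hold for models without shared parameters.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def set_model_params(module, params_list, start_param_idx=0):\n",
    "    # Same traversal as the helper above: own params first, then children\n",
    "    param_idx = start_param_idx\n",
    "    for name in module._parameters:\n",
    "        module._parameters[name] = params_list[param_idx]\n",
    "        param_idx += 1\n",
    "    for child in module._modules.values():\n",
    "        if child is not None:\n",
    "            param_idx = set_model_params(child, params_list, param_idx)\n",
    "    return param_idx\n",
    "\n",
    "# Hypothetical toy model: Linear(4, 3) -> ReLU -> Linear(3, 2)\n",
    "net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))\n",
    "\n",
    "# Plain tensors (not nn.Parameter), in the same order as net.parameters()\n",
    "new_params = [torch.zeros_like(p) for p in net.parameters()]\n",
    "used = set_model_params(net, new_params)\n",
    "\n",
    "print(used)  # 4 tensors consumed: weight and bias of each Linear layer\n",
    "print(net(torch.ones(1, 4)))  # all zeros, since every weight and bias is zero\n",
    "```"
   ]
  },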
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 1: Define the model\n",
    "\n",
    "This model will train on MNIST data. It is very simple, yet it is enough to demonstrate the learning process.\n",
    "It consists of 2 linear layers with a ReLU activation in between:\n",
    "\n",
    "* Linear 784x392\n",
    "* ReLU\n",
    "* Linear 392x10 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 392)\n",
    "        self.fc2 = nn.Linear(392, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "model = Net()"
   ]
  },
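  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`nn.Linear` stores its weight as an `(out_features, in_features)` matrix, so the model above holds four parameter tensors.\n",
    "The self-contained sketch below repeats `Net` and lists their shapes; these four tensors, in this order, are what the notebook later passes around as `model_params`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class Net(nn.Module):\n",
    "    # Same architecture as above: 784 -> 392 -> ReLU -> 10\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 392)\n",
    "        self.fc2 = nn.Linear(392, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc2(nn.functional.relu(self.fc1(x)))\n",
    "\n",
    "net = Net()\n",
    "shapes = [tuple(p.shape) for p in net.parameters()]\n",
    "n_params = sum(p.numel() for p in net.parameters())\n",
    "\n",
    "print(shapes)    # [(392, 784), (392,), (10, 392), (10,)]\n",
    "print(n_params)  # 311650\n",
    "print(net(torch.randn(3, 784)).shape)  # torch.Size([3, 10])\n",
    "```"
   ]
  },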
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 2: Define Training Plan\n",
    "### Loss function \n",
    "The batch size must be passed explicitly because `target.shape[0]` cannot be traced inside a Plan yet (issue [#3554](https://github.com/OpenMined/PySyft/issues/3554)).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def softmax_cross_entropy_with_logits(logits, targets, batch_size):\n",
    "    \"\"\" Calculates softmax cross-entropy, averaged over the batch\n",
    "        Args:\n",
    "            * logits: (NxC) outputs of dense layer\n",
    "            * targets: (NxC) one-hot encoded labels\n",
    "            * batch_size: value of N, temporarily required because Plan cannot trace .shape\n",
    "    \"\"\"\n",
    "    # numerically stable log-softmax (subtracting a constant does not change the result)\n",
    "    norm_logits = logits - logits.max()\n",
    "    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()\n",
    "    # NLL, reduction = mean\n",
    "    return -(targets * log_probs).sum() / batch_size"
   ]
  },
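  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Outside a Plan, this loss can be sanity-checked against PyTorch's built-in `nn.functional.cross_entropy`, which takes integer class labels and averages over the batch by default.\n",
    "A minimal sketch on random data (the function is repeated so the snippet runs on its own):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def softmax_cross_entropy_with_logits(logits, targets, batch_size):\n",
    "    # Same computation as above\n",
    "    norm_logits = logits - logits.max()\n",
    "    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()\n",
    "    return -(targets * log_probs).sum() / batch_size\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(5, 10)\n",
    "labels = torch.randint(0, 10, (5,))\n",
    "targets = F.one_hot(labels, 10).float()\n",
    "\n",
    "ours = softmax_cross_entropy_with_logits(logits, targets, 5)\n",
    "reference = F.cross_entropy(logits, labels)  # reduction='mean' by default\n",
    "print(torch.allclose(ours, reference, atol=1e-5))  # True\n",
    "```"
   ]
  },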
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Optimization function\n",
    " \n",
    "Simply subtracts `lr * grad` from each parameter.\n",
    "\n",
    "Note: the update can't be done in place because of how Autograd/Plan tracing works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def naive_sgd(param, **kwargs):\n",
    "    return param - kwargs['lr'] * param.grad"
   ]
  },
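  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small plain-PyTorch illustration (outside a Plan) of what `naive_sgd` computes. For the toy loss `(w ** 2).sum()` the gradient is `2 * w`, so one step with `lr=0.1` should scale `w` by 0.8.\n",
    "The function returns a new tensor and leaves the original parameter untouched, i.e. the update is out of place.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def naive_sgd(param, **kwargs):\n",
    "    return param - kwargs['lr'] * param.grad\n",
    "\n",
    "w = torch.tensor([1.0, -2.0], requires_grad=True)\n",
    "loss = (w ** 2).sum()  # toy loss; its gradient is 2 * w\n",
    "loss.backward()\n",
    "\n",
    "new_w = naive_sgd(w, lr=0.1)\n",
    "print(new_w.detach())  # approximately [0.8, -1.6]\n",
    "print(w.detach())      # unchanged: [1.0, -2.0]\n",
    "```"
   ]
  },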
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Training Plan procedure\n",
    "\n",
    "We define a routine that takes one batch of training data and the model parameters,\n",
    "performs one SGD step on the loss function defined above, and returns the loss, the accuracy, and the updated parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "@sy.func2plan()\n",
    "def training_plan(X, y, batch_size, lr, model_params):\n",
    "    # inject params into model\n",
    "    set_model_params(model, model_params)\n",
    "\n",
    "    # forward pass\n",
    "    logits = model.forward(X)\n",
    "    \n",
    "    # loss\n",
    "    loss = softmax_cross_entropy_with_logits(logits, y, batch_size)\n",
    "\n",
    "    # backprop\n",
    "    loss.backward()\n",
    "\n",
    "    # step\n",
    "    updated_params = [\n",
    "        naive_sgd(param, lr=lr)\n",
    "        for param in model_params\n",
    "    ]\n",
    "    \n",
    "    # accuracy\n",
    "    pred = th.argmax(logits, dim=1)\n",
    "    target = th.argmax(y, dim=1)\n",
    "    acc = pred.eq(target).sum().float() / batch_size\n",
    "\n",
    "    return (\n",
    "        loss,\n",
    "        acc,\n",
    "        *updated_params\n",
    "    )"
   ]
  },
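  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stripped of the Plan machinery, `training_plan` performs one plain SGD step. The sketch below checks this in plain PyTorch on a hypothetical small model and random data (not `Net` or MNIST):\n",
    "updating each parameter out of place, as `naive_sgd` does, should give the same parameters as `torch.optim.SGD` without momentum or weight decay.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def softmax_cross_entropy_with_logits(logits, targets, batch_size):\n",
    "    norm_logits = logits - logits.max()\n",
    "    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()\n",
    "    return -(targets * log_probs).sum() / batch_size\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical small model and batch, standing in for Net and MNIST\n",
    "model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))\n",
    "reference = copy.deepcopy(model)  # identical copy for the optimizer-based step\n",
    "X = torch.randn(5, 8)\n",
    "y = nn.functional.one_hot(torch.randint(0, 3, (5,)), 3).float()\n",
    "lr = 0.01\n",
    "\n",
    "# Functional step mirroring training_plan: forward, loss, backward, out-of-place update\n",
    "logits = model(X)\n",
    "loss = softmax_cross_entropy_with_logits(logits, y, 5)\n",
    "loss.backward()\n",
    "updated = [(p - lr * p.grad).detach() for p in model.parameters()]\n",
    "acc = (logits.argmax(dim=1) == y.argmax(dim=1)).float().mean()  # accuracy, as in training_plan\n",
    "\n",
    "# The same step with PyTorch's built-in optimizer\n",
    "opt = torch.optim.SGD(reference.parameters(), lr=lr)\n",
    "opt.zero_grad()\n",
    "softmax_cross_entropy_with_logits(reference(X), y, 5).backward()\n",
    "opt.step()\n",
    "\n",
    "print(all(torch.allclose(a, b) for a, b in zip(updated, reference.parameters())))  # True\n",
    "```"
   ]
  },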
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's build this procedure into the Plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Dummy input parameters to make the trace\n",
    "model_params = [param.data for param in model.parameters()]  # raw tensors instead of nn.Parameter\n",
    "X = th.randn(3, 28 * 28)\n",
    "y = nn.functional.one_hot(th.tensor([1, 2, 3]), 10)\n",
    "lr = th.tensor([0.01])\n",
    "batch_size = th.tensor([3.0])\n",
    "\n",
    "_ = training_plan.build(X, y, batch_size, lr, model_params, trace_autograd=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's look inside the Syft Plan and print out the list of operations recorded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):\n",
      "    var_0 = arg_5.t()\n",
      "    var_1 = arg_1.matmul(var_0)\n",
      "    var_2 = arg_6.add(var_1)\n",
      "    var_3 = var_2.relu()\n",
      "    var_4 = arg_7.t()\n",
      "    var_5 = var_3.matmul(var_4)\n",
      "    var_6 = arg_8.add(var_5)\n",
      "    var_7 = var_6.max()\n",
      "    var_8 = var_6.sub(var_7)\n",
      "    var_9 = var_8.exp()\n",
      "    var_10 = var_9.sum(dim=1, keepdim=True)\n",
      "    var_11 = var_10.log()\n",
      "    var_12 = var_8.sub(var_11)\n",
      "    var_13 = arg_2.mul(var_12)\n",
      "    var_14 = var_13.sum()\n",
      "    var_15 = var_14.neg()\n",
      "    out_1 = var_15.div(arg_3)\n",
      "    var_16 = out_1.mul(0)\n",
      "    var_17 = var_16.add(1)\n",
      "    var_18 = var_17.div(arg_3)\n",
      "    var_19 = var_18.mul(-1)\n",
      "    var_20 = var_19.reshape([-1, 1])\n",
      "    var_21 = var_13.mul(0)\n",
      "    var_22 = var_21.add(1)\n",
      "    var_23 = var_22.mul(var_20)\n",
      "    var_24 = var_23.mul(arg_2)\n",
      "    var_25 = var_24.add(0)\n",
      "    var_26 = var_24.mul(-1)\n",
      "    var_27 = var_26.sum(dim=[1], keepdim=True)\n",
      "    var_28 = var_25.add(0)\n",
      "    var_29 = var_28.add(0)\n",
      "    var_30 = var_28.add(0)\n",
      "    var_31 = var_29.sum(dim=[0])\n",
      "    var_32 = var_31.copy()\n",
      "    var_33 = var_4.t()\n",
      "    var_34 = var_30.matmul(var_33)\n",
      "    var_35 = var_3.t()\n",
      "    var_36 = var_35.matmul(var_30)\n",
      "    var_37 = var_2.mul(0)\n",
      "    var_38 = var_2.__gt__(var_37)\n",
      "    var_39 = var_38.mul(var_34)\n",
      "    var_40 = var_39.add(0)\n",
      "    var_41 = var_39.add(0)\n",
      "    var_42 = var_40.sum(dim=[0])\n",
      "    var_43 = var_42.copy()\n",
      "    var_44 = arg_1.t()\n",
      "    var_45 = var_44.matmul(var_41)\n",
      "    var_46 = var_45.t()\n",
      "    var_47 = var_46.copy()\n",
      "    var_48 = var_36.t()\n",
      "    var_49 = var_48.copy()\n",
      "    var_50 = var_10.__rtruediv__(1)\n",
      "    var_51 = var_27.mul(var_50)\n",
      "    var_52 = var_51.reshape([-1, 1])\n",
      "    var_53 = var_9.mul(0)\n",
      "    var_54 = var_53.add(1)\n",
      "    var_55 = var_54.mul(var_52)\n",
      "    var_56 = var_8.exp()\n",
      "    var_57 = var_55.mul(var_56)\n",
      "    var_58 = var_57.add(0)\n",
      "    var_59 = var_58.add(0)\n",
      "    var_60 = var_58.add(0)\n",
      "    var_61 = var_59.sum(dim=[0])\n",
      "    var_32 = var_32.add_(var_61)\n",
      "    var_62 = var_4.t()\n",
      "    var_63 = var_60.matmul(var_62)\n",
      "    var_64 = var_3.t()\n",
      "    var_65 = var_64.matmul(var_60)\n",
      "    var_66 = var_2.mul(0)\n",
      "    var_67 = var_2.__gt__(var_66)\n",
      "    var_68 = var_67.mul(var_63)\n",
      "    var_69 = var_68.add(0)\n",
      "    var_70 = var_68.add(0)\n",
      "    var_71 = var_69.sum(dim=[0])\n",
      "    var_43 = var_43.add_(var_71)\n",
      "    var_72 = arg_1.t()\n",
      "    var_73 = var_72.matmul(var_70)\n",
      "    var_74 = var_73.t()\n",
      "    var_47 = var_47.add_(var_74)\n",
      "    var_75 = var_65.t()\n",
      "    var_49 = var_49.add_(var_75)\n",
      "    var_76 = arg_4.mul(var_47)\n",
      "    out_3 = arg_5.sub(var_76)\n",
      "    var_77 = arg_4.mul(var_43)\n",
      "    out_4 = arg_6.sub(var_77)\n",
      "    var_78 = arg_4.mul(var_49)\n",
      "    out_5 = arg_7.sub(var_78)\n",
      "    var_79 = arg_4.mul(var_32)\n",
      "    out_6 = arg_8.sub(var_79)\n",
      "    var_80 = torch.argmax(var_6, dim=1)\n",
      "    var_81 = torch.argmax(arg_2, dim=1)\n",
      "    var_82 = var_80.eq(var_81)\n",
      "    var_83 = var_82.sum()\n",
      "    var_84 = var_83.float()\n",
      "    out_2 = var_84.div(arg_3)\n",
      "    return out_1, out_2, out_3, out_4, out_5, out_6\n"
     ]
    }
   ],
   "source": [
    "print(training_plan.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The Plan should also be automatically translated to TorchScript and TensorFlow.js.\n",
    "Let's examine the TorchScript code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def <Plan training_plan id:21834869296 owner:me built>\n",
      "(argument_0: Tensor,\n",
      "    argument_1: Tensor,\n",
      "    argument_2: Tensor,\n",
      "    argument_3: Tensor,\n",
      "    argument_4: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n",
      "  _0, _1, _2, _3, = argument_4\n",
      "  _4 = torch.add(_1, torch.matmul(argument_0, torch.t(_0)), alpha=1)\n",
      "  _5 = torch.relu(_4)\n",
      "  _6 = torch.t(_2)\n",
      "  _7 = torch.add(_3, torch.matmul(_5, _6), alpha=1)\n",
      "  _8 = torch.sub(_7, torch.max(_7), alpha=1)\n",
      "  _9 = torch.exp(_8)\n",
      "  _10 = torch.sum(_9, [1], True, dtype=None)\n",
      "  _11 = torch.sub(_8, torch.log(_10), alpha=1)\n",
      "  _12 = torch.mul(argument_1, _11)\n",
      "  _13 = torch.div(torch.neg(torch.sum(_12, dtype=None)), argument_2)\n",
      "  _14 = torch.add(torch.mul(_13, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _15 = torch.mul(torch.div(_14, argument_2), CONSTANTS.c2)\n",
      "  _16 = torch.reshape(_15, [-1, 1])\n",
      "  _17 = torch.add(torch.mul(_12, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _18 = torch.mul(torch.mul(_17, _16), argument_1)\n",
      "  _19 = torch.add(_18, CONSTANTS.c0, alpha=1)\n",
      "  _20 = torch.sum(torch.mul(_18, CONSTANTS.c2), [1], True, dtype=None)\n",
      "  _21 = torch.add(_19, CONSTANTS.c0, alpha=1)\n",
      "  _22 = torch.add(_21, CONSTANTS.c0, alpha=1)\n",
      "  _23 = torch.add(_21, CONSTANTS.c0, alpha=1)\n",
      "  _24 = torch.sum(_22, [0], False, dtype=None)\n",
      "  _25 = torch.matmul(_23, torch.t(_6))\n",
      "  _26 = torch.matmul(torch.t(_5), _23)\n",
      "  _27 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))\n",
      "  _28 = torch.mul(_27, _25)\n",
      "  _29 = torch.add(_28, CONSTANTS.c0, alpha=1)\n",
      "  _30 = torch.add(_28, CONSTANTS.c0, alpha=1)\n",
      "  _31 = torch.sum(_29, [0], False, dtype=None)\n",
      "  _32 = torch.matmul(torch.t(argument_0), _30)\n",
      "  _33 = torch.t(_32)\n",
      "  _34 = torch.t(_26)\n",
      "  _35 = torch.mul(torch.reciprocal(_10), CONSTANTS.c1)\n",
      "  _36 = torch.reshape(torch.mul(_20, _35), [-1, 1])\n",
      "  _37 = torch.add(torch.mul(_9, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _38 = torch.mul(torch.mul(_37, _36), torch.exp(_8))\n",
      "  _39 = torch.add(_38, CONSTANTS.c0, alpha=1)\n",
      "  _40 = torch.add(_39, CONSTANTS.c0, alpha=1)\n",
      "  _41 = torch.add(_39, CONSTANTS.c0, alpha=1)\n",
      "  _42 = torch.sum(_40, [0], False, dtype=None)\n",
      "  _43 = torch.add_(_24, _42, alpha=1)\n",
      "  _44 = torch.matmul(_41, torch.t(_6))\n",
      "  _45 = torch.matmul(torch.t(_5), _41)\n",
      "  _46 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))\n",
      "  _47 = torch.mul(_46, _44)\n",
      "  _48 = torch.add(_47, CONSTANTS.c0, alpha=1)\n",
      "  _49 = torch.add(_47, CONSTANTS.c0, alpha=1)\n",
      "  _50 = torch.sum(_48, [0], False, dtype=None)\n",
      "  _51 = torch.add_(_31, _50, alpha=1)\n",
      "  _52 = torch.matmul(torch.t(argument_0), _49)\n",
      "  _53 = torch.add_(_33, torch.t(_52), alpha=1)\n",
      "  _54 = torch.add_(_34, torch.t(_45), alpha=1)\n",
      "  _55 = torch.sub(_0, torch.mul(argument_3, _53), alpha=1)\n",
      "  _56 = torch.sub(_1, torch.mul(argument_3, _51), alpha=1)\n",
      "  _57 = torch.sub(_2, torch.mul(argument_3, _54), alpha=1)\n",
      "  _58 = torch.sub(_3, torch.mul(argument_3, _43), alpha=1)\n",
      "  _59 = torch.eq(torch.argmax(_7, 1, False), torch.argmax(argument_1, 1, False))\n",
      "  _60 = torch.to(torch.sum(_59, dtype=None), 6, False, False, None)\n",
      "  _61 = (_13, torch.div(_60, argument_2), _55, _56, _57, _58)\n",
      "  return _61\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(training_plan.torchscript.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "TensorFlow.js code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):\n",
      "    var_0 = tf.transpose(arg_5)\n",
      "    var_1 = tf.matMul(arg_1, var_0)\n",
      "    var_2 = tf.add(arg_6, var_1)\n",
      "    var_3 = tf.relu(var_2)\n",
      "    var_4 = tf.transpose(arg_7)\n",
      "    var_5 = tf.matMul(var_3, var_4)\n",
      "    var_6 = tf.add(arg_8, var_5)\n",
      "    var_7 = tf.max(var_6)\n",
      "    var_8 = tf.sub(var_6, var_7)\n",
      "    var_9 = tf.exp(var_8)\n",
      "    var_10 = tf.sum(var_9, 1, keepdim=True)\n",
      "    var_11 = tf.log(var_10)\n",
      "    var_12 = tf.sub(var_8, var_11)\n",
      "    var_13 = tf.mul(arg_2, var_12)\n",
      "    var_14 = tf.sum(var_13)\n",
      "    var_15 = tf.neg(var_14)\n",
      "    out_1 = tf.div(var_15, arg_3)\n",
      "    var_16 = tf.mul(out_1, 0)\n",
      "    var_17 = tf.add(var_16, 1)\n",
      "    var_18 = tf.div(var_17, arg_3)\n",
      "    var_19 = tf.mul(var_18, -1)\n",
      "    var_20 = tf.layers.reshape(var_19, [-1, 1])\n",
      "    var_21 = tf.mul(var_13, 0)\n",
      "    var_22 = tf.add(var_21, 1)\n",
      "    var_23 = tf.mul(var_22, var_20)\n",
      "    var_24 = tf.mul(var_23, arg_2)\n",
      "    var_25 = tf.add(var_24, 0)\n",
      "    var_26 = tf.mul(var_24, -1)\n",
      "    var_27 = tf.sum(var_26, [1], keepdim=True)\n",
      "    var_28 = tf.add(var_25, 0)\n",
      "    var_29 = tf.add(var_28, 0)\n",
      "    var_30 = tf.add(var_28, 0)\n",
      "    var_31 = tf.sum(var_29, [0])\n",
      "    var_32 = clone(var_31)\n",
      "    var_33 = tf.transpose(var_4)\n",
      "    var_34 = tf.matMul(var_30, var_33)\n",
      "    var_35 = tf.transpose(var_3)\n",
      "    var_36 = tf.matMul(var_35, var_30)\n",
      "    var_37 = tf.mul(var_2, 0)\n",
      "    var_38 = tf.greater(var_2, var_37)\n",
      "    var_39 = tf.mul(var_38, var_34)\n",
      "    var_40 = tf.add(var_39, 0)\n",
      "    var_41 = tf.add(var_39, 0)\n",
      "    var_42 = tf.sum(var_40, [0])\n",
      "    var_43 = clone(var_42)\n",
      "    var_44 = tf.transpose(arg_1)\n",
      "    var_45 = tf.matMul(var_44, var_41)\n",
      "    var_46 = tf.transpose(var_45)\n",
      "    var_47 = clone(var_46)\n",
      "    var_48 = tf.transpose(var_36)\n",
      "    var_49 = clone(var_48)\n",
      "    var_50 = tf.div(1, var_10)\n",
      "    var_51 = tf.mul(var_27, var_50)\n",
      "    var_52 = tf.layers.reshape(var_51, [-1, 1])\n",
      "    var_53 = tf.mul(var_9, 0)\n",
      "    var_54 = tf.add(var_53, 1)\n",
      "    var_55 = tf.mul(var_54, var_52)\n",
      "    var_56 = tf.exp(var_8)\n",
      "    var_57 = tf.mul(var_55, var_56)\n",
      "    var_58 = tf.add(var_57, 0)\n",
      "    var_59 = tf.add(var_58, 0)\n",
      "    var_60 = tf.add(var_58, 0)\n",
      "    var_61 = tf.sum(var_59, [0])\n",
      "    var_32 = tf.add(var_32, var_61)\n",
      "    var_62 = tf.transpose(var_4)\n",
      "    var_63 = tf.matMul(var_60, var_62)\n",
      "    var_64 = tf.transpose(var_3)\n",
      "    var_65 = tf.matMul(var_64, var_60)\n",
      "    var_66 = tf.mul(var_2, 0)\n",
      "    var_67 = tf.greater(var_2, var_66)\n",
      "    var_68 = tf.mul(var_67, var_63)\n",
      "    var_69 = tf.add(var_68, 0)\n",
      "    var_70 = tf.add(var_68, 0)\n",
      "    var_71 = tf.sum(var_69, [0])\n",
      "    var_43 = tf.add(var_43, var_71)\n",
      "    var_72 = tf.transpose(arg_1)\n",
      "    var_73 = tf.matMul(var_72, var_70)\n",
      "    var_74 = tf.transpose(var_73)\n",
      "    var_47 = tf.add(var_47, var_74)\n",
      "    var_75 = tf.transpose(var_65)\n",
      "    var_49 = tf.add(var_49, var_75)\n",
      "    var_76 = tf.mul(arg_4, var_47)\n",
      "    out_3 = tf.sub(arg_5, var_76)\n",
      "    var_77 = tf.mul(arg_4, var_43)\n",
      "    out_4 = tf.sub(arg_6, var_77)\n",
      "    var_78 = tf.mul(arg_4, var_49)\n",
      "    out_5 = tf.sub(arg_7, var_78)\n",
      "    var_79 = tf.mul(arg_4, var_32)\n",
      "    out_6 = tf.sub(arg_8, var_79)\n",
      "    var_80 = tf.argMax(var_6, 1)\n",
      "    var_81 = tf.argMax(arg_2, 1)\n",
      "    var_82 = tf.equal(var_80, var_81)\n",
      "    var_83 = tf.sum(var_82)\n",
      "    var_84 = tf.cast(var_83, float32)\n",
      "    out_2 = tf.div(var_84, arg_3)\n",
      "    return out_1, out_2, out_3, out_4, out_5, out_6\n"
     ]
    }
   ],
   "source": [
    "training_plan.base_framework = TranslationTarget.TENSORFLOW_JS.value\n",
    "print(training_plan.code)\n",
    "training_plan.base_framework = TranslationTarget.PYTORCH.value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Define Averaging Plan\n",
    "\n",
    "The Averaging Plan is executed by PyGrid at the end of each cycle\n",
    "to average the _diffs_ submitted by workers, update the model,\n",
    "and create a new checkpoint for the next cycle.\n",
    "\n",
    "A _diff_ is the difference between the client-trained\n",
    "model params and the original model params,\n",
    "so it has the same number of tensors, with the same shapes,\n",
    "as the model parameters.\n",
    "\n",
    "We define a Plan that processes one diff at a time.\n",
    "Such Plans require the `iterative_plan` flag to be set to `True`\n",
    "in `server_config` when hosting the FL model in PyGrid.\n",
    "\n",
    "The Plan below calculates a simple running mean of each parameter."
   ]
  },
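  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the update in the Plan below is the standard running-mean recurrence. For one parameter tensor, let $\\bar{x}_n$ be the mean of the first $n$ diffs (`avg`, with $n$ = `num`) and $x_{n+1}$ the next diff (`item`); all operations are elementwise:\n",
    "\n",
    "$$\\bar{x}_1 = x_1, \\qquad \\bar{x}_{n+1} = \\frac{n\\,\\bar{x}_n + x_{n+1}}{n+1}.$$\n",
    "\n",
    "By induction $n\\,\\bar{x}_n = \\sum_{i=1}^{n} x_i$, so\n",
    "\n",
    "$$\\bar{x}_{n+1} = \\frac{1}{n+1}\\sum_{i=1}^{n+1} x_i,$$\n",
    "\n",
    "i.e. once the last diff is processed, the result equals the plain average of all diffs, regardless of the order in which they arrive (up to floating-point rounding)."
   ]
  },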
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@sy.func2plan()\n",
    "def avg_plan(avg, item, num):\n",
    "    new_avg = []\n",
    "    for i, param in enumerate(avg):\n",
    "        new_avg.append((avg[i] * num + item[i]) / (num + 1))\n",
    "    return new_avg\n",
    "\n",
    "# Build the Plan\n",
    "_ = avg_plan.build(model_params, model_params, th.tensor([1.0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def avg_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8, arg_9):\n",
      "    var_0 = arg_1.__mul__(arg_9)\n",
      "    var_1 = var_0.__add__(arg_5)\n",
      "    var_2 = arg_9.__add__(1)\n",
      "    out_1 = var_1.__truediv__(var_2)\n",
      "    var_3 = arg_2.__mul__(arg_9)\n",
      "    var_4 = var_3.__add__(arg_6)\n",
      "    var_5 = arg_9.__add__(1)\n",
      "    out_2 = var_4.__truediv__(var_5)\n",
      "    var_6 = arg_3.__mul__(arg_9)\n",
      "    var_7 = var_6.__add__(arg_7)\n",
      "    var_8 = arg_9.__add__(1)\n",
      "    out_3 = var_7.__truediv__(var_8)\n",
      "    var_9 = arg_4.__mul__(arg_9)\n",
      "    var_10 = var_9.__add__(arg_8)\n",
      "    var_11 = arg_9.__add__(1)\n",
      "    out_4 = var_10.__truediv__(var_11)\n",
      "    return out_1, out_2, out_3, out_4\n"
     ]
    }
   ],
   "source": [
    "# Let's check Plan contents\n",
    "print(avg_plan.code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Test averaging plan\n",
    "# Pretend there are diffs whose params are all ones * coeff, for each coeff in dummy_coeffs\n",
    "dummy_coeffs = [1, 5.5, 7, 55]\n",
    "dummy_diffs = [[th.ones_like(param) * i for param in model_params] for i in dummy_coeffs]\n",
    "mean_coeff = th.tensor(dummy_coeffs).mean().item()\n",
    "\n",
    "# Remove the original function to make sure we execute the traced Plan\n",
    "avg_plan.forward = None\n",
    "\n",
    "# Calculate avg value using our plan\n",
    "avg = dummy_diffs[0]\n",
    "for i, diff in enumerate(dummy_diffs[1:]):\n",
    "    avg = avg_plan(list(avg), diff, th.tensor([i + 1]))\n",
    "\n",
    "# Avg should be ones*mean_coeff for each param\n",
    "for i, param in enumerate(model_params):\n",
    "    expected = th.ones_like(param) * mean_coeff\n",
    "    assert avg[i].eq(expected).all(), f\"param #{i}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 4: Host in PyGrid\n",
    "\n",
    "Let's now host everything in PyGrid so that it can be accessed by worker libraries (syft.js, KotlinSyft, SwiftSyft, or even PySyft itself)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow the PyGrid [README](https://github.com/OpenMined/PyGrid/#getting-started) to start a PyGrid Node. The code below assumes that the PyGrid Node is running on `127.0.0.1`, port `5000`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Connect to the PyGrid Node and define the model name, version, and configs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyGrid Node address\n",
    "gridAddress = \"127.0.0.1:5000\"\n",
    "grid = ModelCentricFLClient(id=\"test\", address=gridAddress, secure=False)\n",
    "grid.connect()\n",
    "\n",
    "# Workers request the model by this name and version\n",
    "name = \"mnist\"\n",
    "version = \"1.0.0\"\n",
    "\n",
    "client_config = {\n",
    "    \"name\": name,\n",
    "    \"version\": version,\n",
    "    \"batch_size\": 64,\n",
    "    \"lr\": 0.005,\n",
    "    \"max_updates\": 100  # custom syft.js option that limits the number of training loops per worker\n",
    "}\n",
    "\n",
    "server_config = {\n",
    "    \"min_workers\": 5,\n",
    "    \"max_workers\": 5,\n",
    "    \"pool_selection\": \"random\",\n",
    "    \"do_not_reuse_workers_until_cycle\": 6,\n",
    "    \"cycle_length\": 28800,  # max cycle length in seconds\n",
    "    \"num_cycles\": 5,  # max number of cycles\n",
    "    \"max_diffs\": 1,  # number of diffs to collect before avg\n",
    "    \"minimum_upload_speed\": 0,\n",
    "    \"minimum_download_speed\": 0,\n",
    "    \"iterative_plan\": True  # tells PyGrid that avg plan is executed per diff\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Authentication (optional)\n",
    "Let's additionally protect the model with simple authentication for workers.\n",
    "\n",
    "PyGrid supports authentication via JWT tokens (HMAC or RSA) or via opaque tokens\n",
    "checked through a remote API.\n",
    "\n",
    "We'll use JWT with RSA. Suppose we generate an RSA key pair:\n",
    "```\n",
    "openssl genrsa -out private.pem\n",
    "openssl rsa -in private.pem -pubout -out public.pem\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "private_key = \"\"\"\n",
    "-----BEGIN RSA PRIVATE KEY-----\n",
    "MIIEowIBAAKCAQEAzQMcI09qonB9OZT20X3Z/oigSmybR2xfBQ1YJ1oSjQ3YgV+G\n",
    "FUuhEsGDgqt0rok9BreT4toHqniFixddncTHg7EJzU79KZelk2m9I2sEsKUqEsEF\n",
    "lMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYvGFphwwh4TNJXxkCg69/RsvPBIPi2\n",
    "9vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNVQhUFABDyWN4h/67M1eArGA540vyd\n",
    "kYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+LzmjEnjTJqUzr7kM9Rzq3BY01DNi\n",
    "TVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3ZQIDAQABAoIBAD+xbKeHv+BxxGYE\n",
    "Yt5ZFEYhGnOk5GU/RRIjwDSRplvOZmpjTBwHoCZcmsgZDqo/FwekNzzuch1DTnIV\n",
    "M0+V2EqQ0TPJC5xFcfqnikybrhxXZAfpkhtU+gR5lDb5Q+8mkhPAYZdNioG6PGPS\n",
    "oGz8BsuxINhgJEfxvbVpVNWTdun6hLOAMZaH3DHgi0uyTBg8ofARoZP5RIbHwW+D\n",
    "p+5vd9x/x7tByu76nd2UbMp3yqomlB5jQktqyilexCIknEnfb3i/9jqFv8qVE5P6\n",
    "e3jdYoJY+FoomWhqEvtfPpmUFTY5lx4EERCb1qhWG3a7sVBqTwO6jJJBsxy3RLIS\n",
    "Ic0qZcECgYEA6GsBP11a2T4InZ7cixd5qwSeznOFCzfDVvVNI8KUw+n4DOPndpao\n",
    "TUskWOpoV8MyiEGdQHgmTOgGaCXN7bC0ERembK0J64FI3TdKKg0v5nKa7xHb7Qcv\n",
    "t9ccrDZVn4y/Yk5PCqjNWTR3/wDR88XouzIGaWkGlili5IJqdLEvPvUCgYEA4dA+\n",
    "5MNEQmNFezyWs//FS6G3lTRWgjlWg2E6BXXvkEag6G5SBD31v3q9JIjs+sYdOmwj\n",
    "kfkQrxEtbs173xgYWzcDG1FI796LTlJ/YzuoKZml8vEF3T8C4Bkbl6qj9DZljb2j\n",
    "ehjTv5jA256sSUEqOa/mtNFUbFlBjgOZh3TCsLECgYAc701tdRLdXuK1tNRiIJ8O\n",
    "Enou26Thm6SfC9T5sbzRkyxFdo4XbnQvgz5YL36kBnIhEoIgR5UFGBHMH4C+qbQR\n",
    "OK+IchZ9ElBe8gYyrAedmgD96GxH2xAuxAIW0oDgZyZgd71RZ2iBRY322kRJJAdw\n",
    "Xq77qo6eXTKpni7grjpijQKBgDHWRAs5DVeZkTwhoyEW0fRfPKUxZ+ZVwUI9sxCB\n",
    "dt3guKKTtoY5JoOcEyJ9FdBC6TB7rV4KGiSJJf3OXAhgyP9YpNbimbZW52fhzTuZ\n",
    "bwO/ZWC40RKDVZ8f63cNsiGz37XopKvNzu36SJYv7tY8C5WvvLsrd/ZxvIYbRUcf\n",
    "/dgBAoGBAMdR5DXBcOWk3+KyEHXw2qwWcGXyzxtca5SRNLPR2uXvrBYXbhFB/PVj\n",
    "h3rGBsiZbnIvSnSIE+8fFe6MshTl2Qxzw+F2WV3OhhZLLtBnN5qqeSe9PdHLHm49\n",
    "XDce6NV2D1mQLBe8648OI5CScQENuRGxF2/h9igeR4oRRsM1gzJN\n",
    "-----END RSA PRIVATE KEY-----\n",
    "\"\"\".strip()\n",
    "\n",
    "public_key = \"\"\"\n",
    "-----BEGIN PUBLIC KEY-----\n",
    "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzQMcI09qonB9OZT20X3Z\n",
    "/oigSmybR2xfBQ1YJ1oSjQ3YgV+GFUuhEsGDgqt0rok9BreT4toHqniFixddncTH\n",
    "g7EJzU79KZelk2m9I2sEsKUqEsEFlMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYv\n",
    "GFphwwh4TNJXxkCg69/RsvPBIPi29vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNV\n",
    "QhUFABDyWN4h/67M1eArGA540vydkYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+\n",
    "LzmjEnjTJqUzr7kM9Rzq3BY01DNiTVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3\n",
    "ZQIDAQAB\n",
    "-----END PUBLIC KEY-----\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "If we set the __public key__ in the model's authentication config,\n",
    "PyGrid will verify that submitted JWT auth tokens are signed with the matching private key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "server_config[\"authentication\"] = {\n",
    "    \"type\": \"jwt\",\n",
    "    \"pub_key\": public_key,\n",
    "}"
   ]
  },
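  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A JWT consists of three base64url-encoded segments separated by dots: header, payload, and signature. With the config above, PyGrid checks the signature against the public key.\n",
    "The stdlib-only sketch below decodes just the first two segments of the token printed later in this notebook (an RS256 token with an empty payload).\n",
    "It does not verify the signature; that needs the public key and a JWT library such as PyJWT.\n",
    "\n",
    "```python\n",
    "import base64\n",
    "import json\n",
    "\n",
    "def b64url_decode(segment):\n",
    "    # JWT segments are base64url without '=' padding; restore it before decoding\n",
    "    return base64.urlsafe_b64decode(segment + '=' * (-len(segment) % 4))\n",
    "\n",
    "# Header and payload of the token printed later in this notebook (signature omitted)\n",
    "token_prefix = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30'\n",
    "header_b64, payload_b64 = token_prefix.split('.')\n",
    "\n",
    "header = json.loads(b64url_decode(header_b64))\n",
    "payload = json.loads(b64url_decode(payload_b64))\n",
    "print(header)   # {'typ': 'JWT', 'alg': 'RS256'}\n",
    "print(payload)  # {}\n",
    "```"
   ]
  },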
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Now we're ready to host our federated Training Plan!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Host response: {'type': 'model-centric/host-training', 'data': {'status': 'success'}}\n"
     ]
    }
   ],
   "source": [
    "model_params_state = State(\n",
    "    state_placeholders=[\n",
    "        PlaceHolder().instantiate(param)\n",
    "        for param in model_params\n",
    "    ]\n",
    ")\n",
    "\n",
    "response = grid.host_federated_training(\n",
    "    model=model_params_state,\n",
    "    client_plans={'training_plan': training_plan},\n",
    "    client_protocols={},\n",
    "    server_averaging_plan=avg_plan,\n",
    "    client_config=client_config,\n",
    "    server_config=server_config\n",
    ")\n",
    "\n",
    "print(\"Host response:\", response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the response contains `'status': 'success'`, the plan was successfully hosted in PyGrid!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check hosted data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This section is optional: here we double-check that the data is properly hosted in PyGrid by \"manually\" authenticating, requesting a training cycle, and downloading the Model and different variants of the Training Plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper: send a JSON message to PyGrid over WebSocket and return the decoded JSON reply\n",
    "async def sendWsMessage(data):\n",
    "    async with websockets.connect('ws://' + gridAddress) as websocket:\n",
    "        await websocket.send(json.dumps(data))\n",
    "        message = await websocket.recv()\n",
    "        return json.loads(message)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, create an authentication token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pyjwt in /home/vova/anaconda3/lib/python3.7/site-packages (1.7.1)\n",
      "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30.Cn_0cSjCw1QKtcYDx_mYN_q9jO2KkpcUoiVbILmKVB4LUCQvZ7YeuyQ51r9h3562KQoSas_ehbjpz2dw1Dk24hQEoN6ObGxfJDOlemF5flvLO_sqAHJDGGE24JRE4lIAXRK6aGyy4f4kmlICL6wG8sGSpSrkZlrFLOVRJckTptgaiOTIm5Udfmi45NljPBQKVpqXFSmmb3dRy_e8g3l5eBVFLgrBhKPQ1VbNfRK712KlQWs7jJ31fGpW2NxMloO1qcd6rux48quivzQBCvyK8PV5Sqrfw_OMOoNLcSvzePDcZXa2nPHSu3qQIikUdZIeCnkJX-w0t8uEFG3DfH1fVA\n"
     ]
    }
   ],
   "source": [
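    "!pip install pyjwt\n",
    "import jwt\n",
    "\n",
    "# Sign an empty payload with the private key (RS256 also requires the `cryptography` package)\n",
    "auth_token = jwt.encode({}, private_key, algorithm='RS256')\n",
    "# PyJWT < 2.0 returns bytes, while PyJWT >= 2.0 already returns str\n",
    "if isinstance(auth_token, bytes):\n",
    "    auth_token = auth_token.decode('ascii')\n",
    "\n",
    "print(auth_token)"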
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make an authentication request:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Auth response:  {\n",
      "  \"type\": \"model-centric/authenticate\",\n",
      "  \"data\": {\n",
      "    \"status\": \"success\",\n",
      "    \"worker_id\": \"d64bbaf5-777d-4e9a-bfe9-7b02b407ab99\",\n",
      "    \"requires_speed_test\": true\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "auth_request = {\n",
    "    \"type\": \"model-centric/authenticate\",\n",
    "    \"data\": {\n",
    "        \"model_name\": name,\n",
    "        \"model_version\": version,\n",
    "        \"auth_token\": auth_token,\n",
    "    }\n",
    "}\n",
    "auth_response = await sendWsMessage(auth_request)\n",
    "print('Auth response: ', json.dumps(auth_response, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make the cycle request:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cycle response: {\n",
      "  \"type\": \"model-centric/cycle-request\",\n",
      "  \"data\": {\n",
      "    \"status\": \"accepted\",\n",
      "    \"request_key\": \"2096a119e7b382383e21daf15b5a10c211add167a49ebcafb1f81fda2c68d850\",\n",
      "    \"version\": \"1.0.0\",\n",
      "    \"model\": \"mnist\",\n",
      "    \"plans\": {\n",
      "      \"training_plan\": 2\n",
      "    },\n",
      "    \"protocols\": {},\n",
      "    \"client_config\": {\n",
      "      \"name\": \"mnist\",\n",
      "      \"version\": \"1.0.0\",\n",
      "      \"batch_size\": 64,\n",
      "      \"lr\": 0.005,\n",
      "      \"max_updates\": 100\n",
      "    },\n",
      "    \"model_id\": 1\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "cycle_request = {\n",
    "    \"type\": \"model-centric/cycle-request\",\n",
    "    \"data\": {\n",
    "        \"worker_id\": auth_response['data']['worker_id'],\n",
    "        \"model\": name,\n",
    "        \"version\": version,\n",
    "        \"ping\": 1,\n",
    "        \"download\": 10000,\n",
    "        \"upload\": 10000,\n",
    "    }\n",
    "}\n",
    "cycle_response = await sendWsMessage(cycle_request)\n",
    "print('Cycle response:', json.dumps(cycle_response, indent=2))\n",
    "\n",
    "worker_id = auth_response['data']['worker_id']\n",
    "request_key = cycle_response['data']['request_key']\n",
    "model_id = cycle_response['data']['model_id']\n",
    "training_plan_id = cycle_response['data']['plans']['training_plan']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's download the Model and the Training Plan (in various translations) and check that they actually work.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Params shapes: [torch.Size([392, 784]), torch.Size([392]), torch.Size([10, 392]), torch.Size([10])]\n"
     ]
    }
   ],
   "source": [
    "# Model\n",
    "req = requests.get(f\"http://{gridAddress}/model-centric/get-model?worker_id={worker_id}&request_key={request_key}&model_id={model_id}\")\n",
    "model_data = req.content\n",
    "pb = StatePB()\n",
    "pb.ParseFromString(model_data)\n",
    "model_params_downloaded = protobuf.serde._unbufferize(hook.local_worker, pb)\n",
    "print(\"Params shapes:\", [p.shape for p in model_params_downloaded.tensors()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):\n",
      "    var_0 = arg_5.t()\n",
      "    var_1 = arg_1.matmul(var_0)\n",
      "    var_2 = arg_6.add(var_1)\n",
      "    var_3 = var_2.relu()\n",
      "    var_4 = arg_7.t()\n",
      "    var_5 = var_3.matmul(var_4)\n",
      "    var_6 = arg_8.add(var_5)\n",
      "    var_7 = var_6.max()\n",
      "    var_8 = var_6.sub(var_7)\n",
      "    var_9 = var_8.exp()\n",
      "    var_10 = var_9.sum(dim=1, keepdim=True)\n",
      "    var_11 = var_10.log()\n",
      "    var_12 = var_8.sub(var_11)\n",
      "    var_13 = arg_2.mul(var_12)\n",
      "    var_14 = var_13.sum()\n",
      "    var_15 = var_14.neg()\n",
      "    out_1 = var_15.div(arg_3)\n",
      "    var_16 = out_1.mul(0)\n",
      "    var_17 = var_16.add(1)\n",
      "    var_18 = var_17.div(arg_3)\n",
      "    var_19 = var_18.mul(-1)\n",
      "    var_20 = var_19.reshape([-1, 1])\n",
      "    var_21 = var_13.mul(0)\n",
      "    var_22 = var_21.add(1)\n",
      "    var_23 = var_22.mul(var_20)\n",
      "    var_24 = var_23.mul(arg_2)\n",
      "    var_25 = var_24.add(0)\n",
      "    var_26 = var_24.mul(-1)\n",
      "    var_27 = var_26.sum(dim=[1], keepdim=True)\n",
      "    var_28 = var_25.add(0)\n",
      "    var_29 = var_28.add(0)\n",
      "    var_30 = var_28.add(0)\n",
      "    var_31 = var_29.sum(dim=[0])\n",
      "    var_32 = var_31.copy()\n",
      "    var_33 = var_4.t()\n",
      "    var_34 = var_30.matmul(var_33)\n",
      "    var_35 = var_3.t()\n",
      "    var_36 = var_35.matmul(var_30)\n",
      "    var_37 = var_2.mul(0)\n",
      "    var_38 = var_2.__gt__(var_37)\n",
      "    var_39 = var_38.mul(var_34)\n",
      "    var_40 = var_39.add(0)\n",
      "    var_41 = var_39.add(0)\n",
      "    var_42 = var_40.sum(dim=[0])\n",
      "    var_43 = var_42.copy()\n",
      "    var_44 = arg_1.t()\n",
      "    var_45 = var_44.matmul(var_41)\n",
      "    var_46 = var_45.t()\n",
      "    var_47 = var_46.copy()\n",
      "    var_48 = var_36.t()\n",
      "    var_49 = var_48.copy()\n",
      "    var_50 = var_10.__rtruediv__(1)\n",
      "    var_51 = var_27.mul(var_50)\n",
      "    var_52 = var_51.reshape([-1, 1])\n",
      "    var_53 = var_9.mul(0)\n",
      "    var_54 = var_53.add(1)\n",
      "    var_55 = var_54.mul(var_52)\n",
      "    var_56 = var_8.exp()\n",
      "    var_57 = var_55.mul(var_56)\n",
      "    var_58 = var_57.add(0)\n",
      "    var_59 = var_58.add(0)\n",
      "    var_60 = var_58.add(0)\n",
      "    var_61 = var_59.sum(dim=[0])\n",
      "    var_32 = var_32.add_(var_61)\n",
      "    var_62 = var_4.t()\n",
      "    var_63 = var_60.matmul(var_62)\n",
      "    var_64 = var_3.t()\n",
      "    var_65 = var_64.matmul(var_60)\n",
      "    var_66 = var_2.mul(0)\n",
      "    var_67 = var_2.__gt__(var_66)\n",
      "    var_68 = var_67.mul(var_63)\n",
      "    var_69 = var_68.add(0)\n",
      "    var_70 = var_68.add(0)\n",
      "    var_71 = var_69.sum(dim=[0])\n",
      "    var_43 = var_43.add_(var_71)\n",
      "    var_72 = arg_1.t()\n",
      "    var_73 = var_72.matmul(var_70)\n",
      "    var_74 = var_73.t()\n",
      "    var_47 = var_47.add_(var_74)\n",
      "    var_75 = var_65.t()\n",
      "    var_49 = var_49.add_(var_75)\n",
      "    var_76 = arg_4.mul(var_47)\n",
      "    out_3 = arg_5.sub(var_76)\n",
      "    var_77 = arg_4.mul(var_43)\n",
      "    out_4 = arg_6.sub(var_77)\n",
      "    var_78 = arg_4.mul(var_49)\n",
      "    out_5 = arg_7.sub(var_78)\n",
      "    var_79 = arg_4.mul(var_32)\n",
      "    out_6 = arg_8.sub(var_79)\n",
      "    var_80 = torch.argmax(var_6, dim=1)\n",
      "    var_81 = torch.argmax(arg_2, dim=1)\n",
      "    var_82 = var_80.eq(var_81)\n",
      "    var_83 = var_82.sum()\n",
      "    var_84 = var_83.float()\n",
      "    out_2 = var_84.div(arg_3)\n",
      "    return out_1, out_2, out_3, out_4, out_5, out_6\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Plan \"list of ops\"\n",
    "req = requests.get(f\"http://{gridAddress}/model-centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=list\")\n",
    "pb = PlanPB()\n",
    "pb.ParseFromString(req.content)\n",
    "plan_ops = protobuf.serde._unbufferize(hook.local_worker, pb)\n",
    "print(plan_ops.code)\n",
    "print(plan_ops.torchscript)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):\n",
      "    return out_1, out_2, out_3, out_4, out_5, out_6\n",
      "def forward(self,\n",
      "    argument_1: Tensor,\n",
      "    argument_2: Tensor,\n",
      "    argument_3: Tensor,\n",
      "    argument_4: Tensor,\n",
      "    argument_5: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n",
      "  _0, _1, _2, _3, = argument_5\n",
      "  _4 = torch.add(_1, torch.matmul(argument_1, torch.t(_0)), alpha=1)\n",
      "  _5 = torch.relu(_4)\n",
      "  _6 = torch.t(_2)\n",
      "  _7 = torch.add(_3, torch.matmul(_5, _6), alpha=1)\n",
      "  _8 = torch.sub(_7, torch.max(_7), alpha=1)\n",
      "  _9 = torch.exp(_8)\n",
      "  _10 = torch.sum(_9, [1], True, dtype=None)\n",
      "  _11 = torch.sub(_8, torch.log(_10), alpha=1)\n",
      "  _12 = torch.mul(argument_2, _11)\n",
      "  _13 = torch.div(torch.neg(torch.sum(_12, dtype=None)), argument_3)\n",
      "  _14 = torch.add(torch.mul(_13, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _15 = torch.mul(torch.div(_14, argument_3), CONSTANTS.c2)\n",
      "  _16 = torch.reshape(_15, [-1, 1])\n",
      "  _17 = torch.add(torch.mul(_12, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _18 = torch.mul(torch.mul(_17, _16), argument_2)\n",
      "  _19 = torch.add(_18, CONSTANTS.c0, alpha=1)\n",
      "  _20 = torch.sum(torch.mul(_18, CONSTANTS.c2), [1], True, dtype=None)\n",
      "  _21 = torch.add(_19, CONSTANTS.c0, alpha=1)\n",
      "  _22 = torch.add(_21, CONSTANTS.c0, alpha=1)\n",
      "  _23 = torch.add(_21, CONSTANTS.c0, alpha=1)\n",
      "  _24 = torch.sum(_22, [0], False, dtype=None)\n",
      "  _25 = torch.matmul(_23, torch.t(_6))\n",
      "  _26 = torch.matmul(torch.t(_5), _23)\n",
      "  _27 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))\n",
      "  _28 = torch.mul(_27, _25)\n",
      "  _29 = torch.add(_28, CONSTANTS.c0, alpha=1)\n",
      "  _30 = torch.add(_28, CONSTANTS.c0, alpha=1)\n",
      "  _31 = torch.sum(_29, [0], False, dtype=None)\n",
      "  _32 = torch.matmul(torch.t(argument_1), _30)\n",
      "  _33 = torch.t(_32)\n",
      "  _34 = torch.t(_26)\n",
      "  _35 = torch.mul(torch.reciprocal(_10), CONSTANTS.c1)\n",
      "  _36 = torch.reshape(torch.mul(_20, _35), [-1, 1])\n",
      "  _37 = torch.add(torch.mul(_9, CONSTANTS.c0), CONSTANTS.c1, alpha=1)\n",
      "  _38 = torch.mul(torch.mul(_37, _36), torch.exp(_8))\n",
      "  _39 = torch.add(_38, CONSTANTS.c0, alpha=1)\n",
      "  _40 = torch.add(_39, CONSTANTS.c0, alpha=1)\n",
      "  _41 = torch.add(_39, CONSTANTS.c0, alpha=1)\n",
      "  _42 = torch.sum(_40, [0], False, dtype=None)\n",
      "  _43 = torch.add_(_24, _42, alpha=1)\n",
      "  _44 = torch.matmul(_41, torch.t(_6))\n",
      "  _45 = torch.matmul(torch.t(_5), _41)\n",
      "  _46 = torch.gt(_4, torch.mul(_4, CONSTANTS.c0))\n",
      "  _47 = torch.mul(_46, _44)\n",
      "  _48 = torch.add(_47, CONSTANTS.c0, alpha=1)\n",
      "  _49 = torch.add(_47, CONSTANTS.c0, alpha=1)\n",
      "  _50 = torch.sum(_48, [0], False, dtype=None)\n",
      "  _51 = torch.add_(_31, _50, alpha=1)\n",
      "  _52 = torch.matmul(torch.t(argument_1), _49)\n",
      "  _53 = torch.add_(_33, torch.t(_52), alpha=1)\n",
      "  _54 = torch.add_(_34, torch.t(_45), alpha=1)\n",
      "  _55 = torch.sub(_0, torch.mul(argument_4, _53), alpha=1)\n",
      "  _56 = torch.sub(_1, torch.mul(argument_4, _51), alpha=1)\n",
      "  _57 = torch.sub(_2, torch.mul(argument_4, _54), alpha=1)\n",
      "  _58 = torch.sub(_3, torch.mul(argument_4, _43), alpha=1)\n",
      "  _59 = torch.eq(torch.argmax(_7, 1, False), torch.argmax(argument_2, 1, False))\n",
      "  _60 = torch.to(torch.sum(_59, dtype=None), 6, False, False, None)\n",
      "  _61 = (_13, torch.div(_60, argument_3), _55, _56, _57, _58)\n",
      "  return _61\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Plan \"torchscript\"\n",
    "req = requests.get(f\"http://{gridAddress}/model-centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=torchscript\")\n",
    "pb = PlanPB()\n",
    "pb.ParseFromString(req.content)\n",
    "plan_ts = protobuf.serde._unbufferize(hook.local_worker, pb)\n",
    "print(plan_ts.code)\n",
    "print(plan_ts.torchscript.code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):\n",
      "    var_0 = tf.transpose(arg_5)\n",
      "    var_1 = tf.matMul(arg_1, var_0)\n",
      "    var_2 = tf.add(arg_6, var_1)\n",
      "    var_3 = tf.relu(var_2)\n",
      "    var_4 = tf.transpose(arg_7)\n",
      "    var_5 = tf.matMul(var_3, var_4)\n",
      "    var_6 = tf.add(arg_8, var_5)\n",
      "    var_7 = tf.max(var_6)\n",
      "    var_8 = tf.sub(var_6, var_7)\n",
      "    var_9 = tf.exp(var_8)\n",
      "    var_10 = tf.sum(var_9, 1, keepdim=True)\n",
      "    var_11 = tf.log(var_10)\n",
      "    var_12 = tf.sub(var_8, var_11)\n",
      "    var_13 = tf.mul(arg_2, var_12)\n",
      "    var_14 = tf.sum(var_13)\n",
      "    var_15 = tf.neg(var_14)\n",
      "    out_1 = tf.div(var_15, arg_3)\n",
      "    var_16 = tf.mul(out_1, 0)\n",
      "    var_17 = tf.add(var_16, 1)\n",
      "    var_18 = tf.div(var_17, arg_3)\n",
      "    var_19 = tf.mul(var_18, -1)\n",
      "    var_20 = tf.layers.reshape(var_19, [-1, 1])\n",
      "    var_21 = tf.mul(var_13, 0)\n",
      "    var_22 = tf.add(var_21, 1)\n",
      "    var_23 = tf.mul(var_22, var_20)\n",
      "    var_24 = tf.mul(var_23, arg_2)\n",
      "    var_25 = tf.add(var_24, 0)\n",
      "    var_26 = tf.mul(var_24, -1)\n",
      "    var_27 = tf.sum(var_26, [1], keepdim=True)\n",
      "    var_28 = tf.add(var_25, 0)\n",
      "    var_29 = tf.add(var_28, 0)\n",
      "    var_30 = tf.add(var_28, 0)\n",
      "    var_31 = tf.sum(var_29, [0])\n",
      "    var_32 = clone(var_31)\n",
      "    var_33 = tf.transpose(var_4)\n",
      "    var_34 = tf.matMul(var_30, var_33)\n",
      "    var_35 = tf.transpose(var_3)\n",
      "    var_36 = tf.matMul(var_35, var_30)\n",
      "    var_37 = tf.mul(var_2, 0)\n",
      "    var_38 = tf.greater(var_2, var_37)\n",
      "    var_39 = tf.mul(var_38, var_34)\n",
      "    var_40 = tf.add(var_39, 0)\n",
      "    var_41 = tf.add(var_39, 0)\n",
      "    var_42 = tf.sum(var_40, [0])\n",
      "    var_43 = clone(var_42)\n",
      "    var_44 = tf.transpose(arg_1)\n",
      "    var_45 = tf.matMul(var_44, var_41)\n",
      "    var_46 = tf.transpose(var_45)\n",
      "    var_47 = clone(var_46)\n",
      "    var_48 = tf.transpose(var_36)\n",
      "    var_49 = clone(var_48)\n",
      "    var_50 = tf.div(1, var_10)\n",
      "    var_51 = tf.mul(var_27, var_50)\n",
      "    var_52 = tf.layers.reshape(var_51, [-1, 1])\n",
      "    var_53 = tf.mul(var_9, 0)\n",
      "    var_54 = tf.add(var_53, 1)\n",
      "    var_55 = tf.mul(var_54, var_52)\n",
      "    var_56 = tf.exp(var_8)\n",
      "    var_57 = tf.mul(var_55, var_56)\n",
      "    var_58 = tf.add(var_57, 0)\n",
      "    var_59 = tf.add(var_58, 0)\n",
      "    var_60 = tf.add(var_58, 0)\n",
      "    var_61 = tf.sum(var_59, [0])\n",
      "    var_32 = tf.add(var_32, var_61)\n",
      "    var_62 = tf.transpose(var_4)\n",
      "    var_63 = tf.matMul(var_60, var_62)\n",
      "    var_64 = tf.transpose(var_3)\n",
      "    var_65 = tf.matMul(var_64, var_60)\n",
      "    var_66 = tf.mul(var_2, 0)\n",
      "    var_67 = tf.greater(var_2, var_66)\n",
      "    var_68 = tf.mul(var_67, var_63)\n",
      "    var_69 = tf.add(var_68, 0)\n",
      "    var_70 = tf.add(var_68, 0)\n",
      "    var_71 = tf.sum(var_69, [0])\n",
      "    var_43 = tf.add(var_43, var_71)\n",
      "    var_72 = tf.transpose(arg_1)\n",
      "    var_73 = tf.matMul(var_72, var_70)\n",
      "    var_74 = tf.transpose(var_73)\n",
      "    var_47 = tf.add(var_47, var_74)\n",
      "    var_75 = tf.transpose(var_65)\n",
      "    var_49 = tf.add(var_49, var_75)\n",
      "    var_76 = tf.mul(arg_4, var_47)\n",
      "    out_3 = tf.sub(arg_5, var_76)\n",
      "    var_77 = tf.mul(arg_4, var_43)\n",
      "    out_4 = tf.sub(arg_6, var_77)\n",
      "    var_78 = tf.mul(arg_4, var_49)\n",
      "    out_5 = tf.sub(arg_7, var_78)\n",
      "    var_79 = tf.mul(arg_4, var_32)\n",
      "    out_6 = tf.sub(arg_8, var_79)\n",
      "    var_80 = tf.argMax(var_6, 1)\n",
      "    var_81 = tf.argMax(arg_2, 1)\n",
      "    var_82 = tf.equal(var_80, var_81)\n",
      "    var_83 = tf.sum(var_82)\n",
      "    var_84 = tf.cast(var_83, float32)\n",
      "    out_2 = tf.div(var_84, arg_3)\n",
      "    return out_1, out_2, out_3, out_4, out_5, out_6\n"
     ]
    }
   ],
   "source": [
    "# Plan \"tfjs\"\n",
    "req = requests.get(f\"http://{gridAddress}/model-centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=tfjs\")\n",
    "pb = PlanPB()\n",
    "pb.ParseFromString(req.content)\n",
    "plan_tfjs = protobuf.serde._unbufferize(hook.local_worker, pb)\n",
    "print(plan_tfjs.code)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 5: Train\n",
    "\n",
    "To train the hosted model, use one of the existing FL workers:\n",
    " * PySyft: see the \"[Part 02 - Execute Plan](Part%2002%20-%20Execute%20Plan.ipynb)\" notebook, which has an example of using the Python FL worker.\n",
    " * [SwiftSyft](https://github.com/OpenMined/SwiftSyft)\n",
    " * [KotlinSyft](https://github.com/OpenMined/KotlinSyft)\n",
    " * [syft.js](https://github.com/OpenMined/syft.js)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}